"""

"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=import-error
# pylint: disable=missing-function-docstring
import os
import sys
sys.path.extend(["../"]) # pylint: disable=wrong-import-position
import random
from time import time
import warnings
import pickle
import datetime
import socket

import shutil
import yaml
import numpy as np
import torch

from sklearn import linear_model

from data_utils import *
from graph_dict import *
from utils import *
from model import SAGE

# Silence every Python warning for the whole run (this also hides deprecation warnings).
warnings.filterwarnings('ignore')
# NOTE(review): presumably a workaround for certificate errors during downloads;
# with some `requests` versions an empty CURL_CA_BUNDLE effectively disables TLS
# certificate verification — confirm this is still needed.
os.environ["CURL_CA_BUNDLE"] = ""
# Device used for SDMP, the MLP and the labels. Hard-coded: adjust per machine.
DEVICE = 'cuda:2'

# Leftover line_profiler (`%lprun -f`) hint; it has no effect in a plain script.
# %lprun -f

# Parameter settings and data loading
DATA_ROOT_FOLDER = "./dataset"
CONF_ROOT_FOLDER = "./config"
# NOTE(review): not referenced anywhere else in this file.
RES_ROOT_FOLDER = "./result"


#############################################
# Dataset / target-GNN presets: exactly one CONF_NAME + TARGET_GNN_FOLDER pair
# should be active (uncommented).
# # cora
# CONF_NAME = "cora/SDMP/cora_SDMP_base.yml"
# TARGET_GNN_FOLDER = '../result/cora/SAGE'

# # pubmed
# CONF_NAME = "pubmed/SDMP/pubmed_SDMP_base.yml"
# TARGET_GNN_FOLDER = '../result/pubmed/SAGE'

# # reddit
# CONF_NAME = "reddit/SDMP/reddit_SDMP_base.yml"
# CONF_NAME = "reddit/SDMP/reddit_SDMP_no_feature_norm.yml"
# NOTE(review): the reddit preset points at the cora SAGE folder — likely a
# copy-paste slip; confirm before enabling.
# TARGET_GNN_FOLDER = '../result/cora/SAGE'

# OGBN-products (active preset)
# CONF_NAME = "ogbn-products/SDMP/ogbn-products_SDMP_base.yml"
# TARGET_GNN_FOLDER = '../result/ogbn-products/SAGE'
CONF_NAME = "ogbn-products/SDMP/ogbn-products_SDMP_RevGNN-112.yml"
# NOTE(review): overridden by the absolute path assigned further below.
TARGET_GNN_FOLDER = '../result/ogbn-products/RevGNN-112'

# # OGBN-arxiv
# CONF_NAME = 'ogbn-arxiv/SDMP/ogbn-arxiv_SDMP_DRGAT_large_budget.yml'
# NOTE(review): './result' here vs '../result' in the other presets — confirm.
# TARGET_GNN_FOLDER = './result/ogbn-arxiv/DRGAT'

# #######################################
# Initialization and hyper-parameter setting
# Hour offset added to the epoch time before formatting the timestamp.
_TIME_ZONE = 0
TIMESTAMP = time()
# NOTE(review): TIMESTAMP_FORMATTED and HOST_NAME are not used later in this file.
TIMESTAMP_FORMATTED = datetime.datetime.fromtimestamp(
    int(TIMESTAMP)+_TIME_ZONE*3600).strftime('%Y%m%d-%H%M%S')
HOST_NAME = socket.gethostname()

# First config load. In this file its only use is the dataset name for
# DATA_FOLDER; `train_conf` is reloaded from the MLP run config below.
conf_path = os.path.join(CONF_ROOT_FOLDER, CONF_NAME)
train_conf = load_train_conf(conf_path)

# Root of the SDMP preprocessing cache (used as <DATA_FOLDER>/SDMPPre below).
# NOTE(review): derived from the CONF_NAME config, not from the MLP run config
# loaded later; the two "name" fields are assumed to match.
DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf["name"])
# if not os.path.exists(DATA_FOLDER):
#     os.makedirs(DATA_FOLDER)

# # sys.stdout = Logger(os.path.join(RES_FOLDER, "log.txt"))

# print(train_conf)

# parameters
# Run directory of the trained SDMP + MLP experiment being benchmarked.
# NOTE(review): absolute, machine-specific paths; this also overrides the
# TARGET_GNN_FOLDER chosen in the preset section above.
result_root_path = "/home/yaochen/Working/localhashgnn/result/best/ogbn-products_RevGNN-112_MLP_TEST/blackhole_2023-09-06_00:38:48"
TARGET_GNN_FOLDER = "/home/yaochen/Working/localhashgnn/result/ogbn-products/RevGNN-112"

# Artifacts stored inside the run directory.
SDMP_CONF_PATH = os.path.join(result_root_path, "dict_conf.yml")
MLP_CONF_PATH = os.path.join(result_root_path, "conf.yml")
MLP_STATE_PATH = os.path.join(result_root_path, "models", "state_dict_0")

# loading data
# Replaces the earlier `train_conf` with the MLP run's own config.
train_conf = load_train_conf(MLP_CONF_PATH)
sdmp_conf = load_train_conf(SDMP_CONF_PATH)
print(train_conf)
# `g` is later used for val_idx, num_nodes(), ndata['label'] and num_classes.
g = load_data(train_conf["name"])

# Target-GNN artifacts, with file names taken from the SDMP config and resolved
# against TARGET_GNN_FOLDER.
# NOTE(review): GNN_ACC_PATH is not used later in this file.
GNN_MODEL_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_path'])
GNN_CONF_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_conf_path'])
GNN_ACC_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_metric_path'])

# SDMP preprocessing, driven entirely by sdmp_conf; results are cached under
# <DATA_FOLDER>/SDMPPre (use_cache=True).
preprocesser = SDMPDataPre(train_conf["name"], sdmp_conf["feature_normalize"],
                           sdmp_conf["target_h_mode"],
                           GNN_CONF_PATH, GNN_MODEL_PATH, sdmp_conf["target_h_model"], 
                           sdmp_conf["h_init_theta_mode"], sdmp_conf["h_init_theta_k"],
                           sdmp_conf["h_init_theta_k_fanout"],
                           sdmp_conf["theta_cand_mode"], sdmp_conf["theta_cand_k2"],
                           sdmp_conf["theta_cand_k1"], sdmp_conf["theta_cand_fanout"],
                           sdmp_conf["theta_cand_add_self"],
                           use_cache=True, cache_path=os.path.join(DATA_FOLDER, "SDMPPre"),
                           device=DEVICE)
preprocesser.disp_states()
theta_cand, h_init_theta, X, target = preprocesser.theta_cand, preprocesser.h_init_theta, preprocesser.X, preprocesser.target

# construct the SDMP
# Build SDMP from the preprocessed tensors, then restore its trained state
# from the run directory (log file "dict_log.pkl").
print("Initializing the model...")
sdmp = SDMP(X,
            target,
            theta_cand,
            h_init_theta,
            sdmp_conf,
            device=DEVICE,
            verbose=True)
sdmp.load(result_root_path, log_name="dict_log.pkl")

# construct the MLP
# MLP input width = `out_features` of the last layer of the first child module
# of `sdmp.h` (assumes that child is an indexable container such as
# nn.Sequential ending in a Linear-like layer — confirm).
in_size = list(sdmp.h.children())[0][-1].out_features
hidden_size = [train_conf['hidden_size']] * train_conf['hidden_layer']
out_size = g.num_classes
model = MLP(in_size, hidden_size, out_size, dropout=train_conf['dropout']).to(DEVICE)
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# trusted run artifacts here.
with open(MLP_STATE_PATH, 'rb') as fin:
    model.load_state_dict(pickle.load(fin))

import torch.nn.functional as F
#  main test
def test_inference_time(model, sdmp, device, g, eval_size=100, ind=None):
    """Time per-node SDMP feature computation and MLP inference on validation nodes.

    Each selected validation node is fed through its own batch of size 1.
    Two wall-clock durations (seconds) are recorded per node:

    * sampling time: from the end of the previous iteration (or the start of the
      loop) until SDMP has produced the node's features. This includes fetching
      the batch from the loader.
    * inference time: the MLP forward pass on those features.

    When ``device`` is a CUDA device, the device is synchronized before each
    clock read, so the timings include the asynchronous GPU work.
    The F1 score (``torch_f1``) over all evaluated nodes is printed at the end.

    Args:
        model: classifier applied to the SDMP features.
        sdmp: SDMP module; ``efficient_prepare`` is called once before timing.
        device: device the labels are moved to, and the device that is
            synchronized for timing when it is a CUDA device.
        g: graph exposing ``val_idx``, ``num_nodes()`` and ``ndata['label']``.
        eval_size: number of validation nodes to evaluate.
        ind: positions into ``g.val_idx``; the first ``eval_size`` are used.
            Defaults to the module-level ``all_ind``.

    Returns:
        ``(sampling_time, infer_time)``: two lists of per-node durations in seconds.
    """
    if ind is None:
        ind = all_ind  # module-level shuffled positions, set before this is called
    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        # CUDA kernels are launched asynchronously; wait for them so wall-clock
        # timings cover the actual GPU work, not just the kernel launches.
        if is_cuda:
            torch.cuda.synchronize(device)

    sampling_time, infer_time = [], []
    val_idx = g.val_idx[ind[:eval_size]].to("cpu")
    tic_data_loader = time()
    val_dataloader = PlainLoader(torch.from_numpy(np.arange(g.num_nodes())),
                                 g.ndata['label'].to("cpu"),
                                 1, val_idx)
    print("data loader created in {:.6f} seconds".format(time() - tic_data_loader))
    model.eval()
    sdmp.efficient_prepare()
    with torch.no_grad():
        y_hat_list, y_list = [], []
        _sync()  # don't let efficient_prepare's pending GPU work leak into node 1
        tic_sampling = time()
        for x, y in val_dataloader:
            # feat = sdmp.infer_torch_node_approximal_features_idx(x)
            feat = sdmp.efficient_node_wise_infer(x)
            _sync()
            sampling_time.append(time() - tic_sampling)

            tic_infer = time()
            y_hat = model(feat)
            _sync()
            infer_time.append(time() - tic_infer)

            # Moving the label is bookkeeping for the F1 score, so keep it
            # outside the inference timer.
            y_hat_list.append(y_hat)
            y_list.append(y.to(device))
            torch.cuda.empty_cache()
            tic_sampling = time()
        metric = torch_f1(torch.cat(y_hat_list), torch.cat(y_list))
    print(f"f1 score {metric}")
    return sampling_time, infer_time

# Seeded shuffle of validation positions, so every run evaluates the same node
# subset. test_inference_time reads the module-level `all_ind`.
random.seed(666)
all_ind = list(range(g.val_idx.size(0)))
random.shuffle(all_ind)

# Single benchmark run. Each result is wrapped in a per-run list, which is the
# layout gen_res_table expects (one table column per run).
sampling_time, infer_time = test_inference_time(model, sdmp, DEVICE, g)
all_sample_time = [sampling_time]
all_infer_time = [infer_time]

# final string generation

def gen_res_table(stime, itime, func_list, func_name_list, head_list, warmup=1):
    """Render per-iteration timings as LaTeX tabular rows.

    Args:
        stime: one list of per-iteration sampling times (seconds) per column.
        itime: one list of per-iteration inference times (seconds) per column,
            aligned element-wise with ``stime``.
        func_list: reductions (e.g. mean, percentile), one per output row.
        func_name_list: row labels, aligned with ``func_list``.
        head_list: column headers, one per entry of ``stime`` / ``itime``.
        warmup: number of leading samples dropped from every series before
            reducing (presumably to exclude warm-up iterations). The default of 1
            matches the previous hard-coded behavior.

    Returns:
        A string of ``\\hline``-delimited LaTeX rows. Each cell is
        ``sampling/inference/total`` in milliseconds with 5 decimals, where
        total is the per-iteration sum of sampling and inference time.
    """
    def apply_func(time_list, func):
        return [func(series[warmup:]) for series in time_list]

    # Per-iteration end-to-end time. It does not depend on the statistic, so
    # build it once rather than once per row.
    total_time = [[s + i for s, i in zip(s_series, i_series)]
                  for s_series, i_series in zip(stime, itime)]

    # Raw strings: "\h" is not a valid escape sequence (SyntaxWarning on 3.12+).
    rows = [r"\hline",
            "Neighbor size & " + " & ".join(head_list) + r" \\ \hline"]
    for name, func in zip(func_name_list, func_list):
        svals = apply_func(stime, func)
        ivals = apply_func(itime, func)
        tvals = apply_func(total_time, func)
        comb = ["{:.5f}/{:.5f}/{:.5f}".format(s * 1000, i * 1000, t * 1000)
                for s, i, t in zip(svals, ivals, tvals)]
        rows.append(name + " & " + " & ".join(comb) + r" \\")
    return "\n".join(rows) + r" \hline"
        
# Statistics reported in the table. Each entry pairs a row label with the
# reduction applied to one run's per-node timings.
_ROW_STATS = [
    ("mean", np.mean),
    ("90-percentile", lambda x: np.percentile(x, 90)),
    ("99-percentile", lambda x: np.percentile(x, 99)),
    ("max", np.max),
    ("std", np.std),
]
func_name_list = [label for label, _ in _ROW_STATS]
func_list = [stat for _, stat in _ROW_STATS]
# One column per benchmarked model.
head_list = ["MLP"]

res_str = gen_res_table(all_sample_time, all_infer_time,
                        func_list, func_name_list, head_list)
print(res_str)